import numpy as np
import pickle as pkl
import networkx as nx
import scipy.sparse as sp
import torch

pascal_graph = {
    0: [0],
    1: [1, 2],
    2: [1, 2, 3, 5],
    3: [2, 3, 4],
    4: [3, 4],
    5: [2, 5, 6],
    6: [5, 6],
}

cihp_graph = {
    0: [],
    1: [2, 13],
    2: [1, 13],
    3: [14, 15],
    4: [13],
    5: [6, 7, 9, 10, 11, 12, 14, 15],
    6: [5, 7, 10, 11, 14, 15, 16, 17],
    7: [5, 6, 9, 10, 11, 12, 14, 15],
    8: [16, 17, 18, 19],
    9: [5, 7, 10, 16, 17, 18, 19],
    10: [5, 6, 7, 9, 11, 12, 13, 14, 15, 16, 17],
    11: [5, 6, 7, 10, 13],
    12: [5, 7, 10, 16, 17],
    13: [1, 2, 4, 10, 11],
    14: [3, 5, 6, 7, 10],
    15: [3, 5, 6, 7, 10],
    16: [6, 8, 9, 10, 12, 18],
    17: [6, 8, 9, 10, 12, 19],
    18: [8, 9, 16],
    19: [8, 9, 17],
}

atr_graph = {
    0: [],
    1: [2, 11],
    2: [1, 11],
    3: [11],
    4: [5, 6, 7, 11, 14, 15, 17],
    5: [4, 6, 7, 8, 12, 13],
    6: [4, 5, 7, 8, 9, 10, 12, 13],
    7: [4, 11, 12, 13, 14, 15],
    8: [5, 6],
    9: [6, 12],
    10: [6, 13],
    11: [1, 2, 3, 4, 7, 14, 15, 17],
    12: [5, 6, 7, 9],
    13: [5, 6, 7, 10],
    14: [4, 7, 11, 16],
    15: [4, 7, 11, 16],
    16: [14, 15],
    17: [4, 11],
}

cihp2pascal_adj = np.array(
    [
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
    ]
)

cihp2pascal_nlp_adj = np.array(
    [
        [
            1.0,
            0.35333052,
            0.32727194,
            0.17418084,
            0.18757584,
            0.40608522,
            0.37503981,
            0.35448462,
            0.22598555,
            0.23893579,
            0.33064262,
            0.28923404,
            0.27986573,
            0.4211553,
            0.36915778,
            0.41377746,
            0.32485771,
            0.37248222,
            0.36865639,
            0.41500332,
        ],
        [
            0.39615879,
            0.46201529,
            0.52321467,
            0.30826114,
            0.25669527,
            0.54747773,
            0.3670523,
            0.3901983,
            0.27519473,
            0.3433325,
            0.52728509,
            0.32771333,
            0.34819325,
            0.63882953,
            0.68042925,
            0.69368576,
            0.63395791,
            0.65344337,
            0.59538781,
            0.6071375,
        ],
        [
            0.16373166,
            0.21663339,
            0.3053872,
            0.28377612,
            0.1372435,
            0.4448808,
            0.29479995,
            0.31092595,
            0.22703953,
            0.33983576,
            0.75778818,
            0.2619818,
            0.37069392,
            0.35184867,
            0.49877512,
            0.49979437,
            0.51853277,
            0.52517541,
            0.32517741,
            0.32377309,
        ],
        [
            0.32687232,
            0.38482461,
            0.37693463,
            0.41610834,
            0.20415749,
            0.76749079,
            0.35139853,
            0.3787411,
            0.28411737,
            0.35155421,
            0.58792618,
            0.31141718,
            0.40585111,
            0.51189218,
            0.82042737,
            0.8342413,
            0.70732188,
            0.72752501,
            0.60327325,
            0.61431337,
        ],
        [
            0.34069369,
            0.34817292,
            0.37525998,
            0.36497069,
            0.17841617,
            0.69746208,
            0.31731463,
            0.34628951,
            0.25167277,
            0.32072379,
            0.56711286,
            0.24894776,
            0.37000453,
            0.52600859,
            0.82483993,
            0.84966274,
            0.7033991,
            0.73449378,
            0.56649608,
            0.58888791,
        ],
        [
            0.28477487,
            0.35139564,
            0.42742352,
            0.41664321,
            0.20004676,
            0.78566833,
            0.42237487,
            0.41048549,
            0.37933812,
            0.46542516,
            0.62444759,
            0.3274493,
            0.49466009,
            0.49314658,
            0.71244233,
            0.71497003,
            0.8234787,
            0.83566589,
            0.62597135,
            0.62626812,
        ],
        [
            0.3011378,
            0.31775977,
            0.42922647,
            0.36896257,
            0.17597556,
            0.72214655,
            0.39162804,
            0.38137872,
            0.34980296,
            0.43818419,
            0.60879174,
            0.26762545,
            0.46271161,
            0.51150476,
            0.72318109,
            0.73678399,
            0.82620388,
            0.84942166,
            0.5943811,
            0.60607602,
        ],
    ]
)

pascal2atr_nlp_adj = np.array(
    [
        [
            1.0,
            0.35333052,
            0.32727194,
            0.18757584,
            0.40608522,
            0.27986573,
            0.23893579,
            0.27600672,
            0.30964391,
            0.36865639,
            0.41500332,
            0.4211553,
            0.32485771,
            0.37248222,
            0.36915778,
            0.41377746,
            0.32006291,
            0.28923404,
        ],
        [
            0.39615879,
            0.46201529,
            0.52321467,
            0.25669527,
            0.54747773,
            0.34819325,
            0.3433325,
            0.26603942,
            0.45162929,
            0.59538781,
            0.6071375,
            0.63882953,
            0.63395791,
            0.65344337,
            0.68042925,
            0.69368576,
            0.44354613,
            0.32771333,
        ],
        [
            0.16373166,
            0.21663339,
            0.3053872,
            0.1372435,
            0.4448808,
            0.37069392,
            0.33983576,
            0.26563416,
            0.35443504,
            0.32517741,
            0.32377309,
            0.35184867,
            0.51853277,
            0.52517541,
            0.49877512,
            0.49979437,
            0.21750868,
            0.2619818,
        ],
        [
            0.32687232,
            0.38482461,
            0.37693463,
            0.20415749,
            0.76749079,
            0.40585111,
            0.35155421,
            0.28271333,
            0.52684576,
            0.60327325,
            0.61431337,
            0.51189218,
            0.70732188,
            0.72752501,
            0.82042737,
            0.8342413,
            0.40137029,
            0.31141718,
        ],
        [
            0.34069369,
            0.34817292,
            0.37525998,
            0.17841617,
            0.69746208,
            0.37000453,
            0.32072379,
            0.27268885,
            0.47426719,
            0.56649608,
            0.58888791,
            0.52600859,
            0.7033991,
            0.73449378,
            0.82483993,
            0.84966274,
            0.37830796,
            0.24894776,
        ],
        [
            0.28477487,
            0.35139564,
            0.42742352,
            0.20004676,
            0.78566833,
            0.49466009,
            0.46542516,
            0.32662614,
            0.55780359,
            0.62597135,
            0.62626812,
            0.49314658,
            0.8234787,
            0.83566589,
            0.71244233,
            0.71497003,
            0.41223219,
            0.3274493,
        ],
        [
            0.3011378,
            0.31775977,
            0.42922647,
            0.17597556,
            0.72214655,
            0.46271161,
            0.43818419,
            0.3192333,
            0.50979216,
            0.5943811,
            0.60607602,
            0.51150476,
            0.82620388,
            0.84942166,
            0.72318109,
            0.73678399,
            0.39259827,
            0.26762545,
        ],
    ]
)

cihp2atr_nlp_adj = np.array(
    [
        [
            1.0,
            0.35333052,
            0.32727194,
            0.18757584,
            0.40608522,
            0.27986573,
            0.23893579,
            0.27600672,
            0.30964391,
            0.36865639,
            0.41500332,
            0.4211553,
            0.32485771,
            0.37248222,
            0.36915778,
            0.41377746,
            0.32006291,
            0.28923404,
        ],
        [
            0.35333052,
            1.0,
            0.39206695,
            0.42143438,
            0.4736689,
            0.47139544,
            0.51999208,
            0.38354847,
            0.45628529,
            0.46514124,
            0.50083501,
            0.4310595,
            0.39371443,
            0.4319752,
            0.42938598,
            0.46384034,
            0.44833757,
            0.6153155,
        ],
        [
            0.32727194,
            0.39206695,
            1.0,
            0.32836702,
            0.52603065,
            0.39543695,
            0.3622627,
            0.43575346,
            0.33866223,
            0.45202552,
            0.48421,
            0.53669903,
            0.47266611,
            0.50925436,
            0.42286557,
            0.45403656,
            0.37221304,
            0.40999322,
        ],
        [
            0.17418084,
            0.46892601,
            0.25774838,
            0.31816231,
            0.39330317,
            0.34218382,
            0.48253904,
            0.22084125,
            0.41335728,
            0.52437572,
            0.5191713,
            0.33576117,
            0.44230914,
            0.44250678,
            0.44330833,
            0.43887264,
            0.50693611,
            0.39278795,
        ],
        [
            0.18757584,
            0.42143438,
            0.32836702,
            1.0,
            0.35030067,
            0.30110947,
            0.41055555,
            0.34338879,
            0.34336307,
            0.37704433,
            0.38810141,
            0.34702081,
            0.24171562,
            0.25433078,
            0.24696241,
            0.2570884,
            0.4465962,
            0.45263213,
        ],
        [
            0.40608522,
            0.4736689,
            0.52603065,
            0.35030067,
            1.0,
            0.54372584,
            0.58300258,
            0.56674191,
            0.555266,
            0.66599594,
            0.68567555,
            0.55716359,
            0.62997328,
            0.65638548,
            0.61219615,
            0.63183318,
            0.54464151,
            0.44293752,
        ],
        [
            0.37503981,
            0.50675565,
            0.4761106,
            0.37561813,
            0.60419403,
            0.77912403,
            0.64595517,
            0.85939662,
            0.46037144,
            0.52348817,
            0.55875094,
            0.37741886,
            0.455671,
            0.49434392,
            0.38479954,
            0.41804074,
            0.47285709,
            0.57236283,
        ],
        [
            0.35448462,
            0.50576632,
            0.51030446,
            0.35841033,
            0.55106903,
            0.50257274,
            0.52591451,
            0.4283053,
            0.39991808,
            0.42327211,
            0.42853819,
            0.42071825,
            0.41240559,
            0.42259136,
            0.38125352,
            0.3868255,
            0.47604934,
            0.51811717,
        ],
        [
            0.22598555,
            0.5053299,
            0.36301185,
            0.38002282,
            0.49700941,
            0.45625243,
            0.62876479,
            0.4112051,
            0.33944371,
            0.48322639,
            0.50318714,
            0.29207815,
            0.38801966,
            0.41119094,
            0.29199072,
            0.31021029,
            0.41594871,
            0.54961962,
        ],
        [
            0.23893579,
            0.51999208,
            0.3622627,
            0.41055555,
            0.58300258,
            0.68874251,
            1.0,
            0.56977937,
            0.49918447,
            0.48484363,
            0.51615925,
            0.41222306,
            0.49535971,
            0.53134951,
            0.3807616,
            0.41050298,
            0.48675801,
            0.51112664,
        ],
        [
            0.33064262,
            0.306412,
            0.60679935,
            0.25592294,
            0.58738706,
            0.40379627,
            0.39679161,
            0.33618385,
            0.39235148,
            0.45474013,
            0.4648476,
            0.59306762,
            0.58976007,
            0.60778661,
            0.55400397,
            0.56551297,
            0.3698029,
            0.33860535,
        ],
        [
            0.28923404,
            0.6153155,
            0.40999322,
            0.45263213,
            0.44293752,
            0.60359359,
            0.51112664,
            0.46578181,
            0.45656936,
            0.38142307,
            0.38525582,
            0.33327223,
            0.35360175,
            0.36156453,
            0.3384992,
            0.34261229,
            0.49297863,
            1.0,
        ],
        [
            0.27986573,
            0.47139544,
            0.39543695,
            0.30110947,
            0.54372584,
            1.0,
            0.68874251,
            0.67765588,
            0.48690078,
            0.44010641,
            0.44921156,
            0.32321099,
            0.48311542,
            0.4982002,
            0.39378102,
            0.40297733,
            0.45309735,
            0.60359359,
        ],
        [
            0.4211553,
            0.4310595,
            0.53669903,
            0.34702081,
            0.55716359,
            0.32321099,
            0.41222306,
            0.25721705,
            0.36633509,
            0.5397475,
            0.56429928,
            1.0,
            0.55796926,
            0.58842844,
            0.57930828,
            0.60410597,
            0.41615326,
            0.33327223,
        ],
        [
            0.36915778,
            0.42938598,
            0.42286557,
            0.24696241,
            0.61219615,
            0.39378102,
            0.3807616,
            0.28089866,
            0.48450394,
            0.77400821,
            0.68813814,
            0.57930828,
            0.8856886,
            0.81673412,
            1.0,
            0.92279623,
            0.46969152,
            0.3384992,
        ],
        [
            0.41377746,
            0.46384034,
            0.45403656,
            0.2570884,
            0.63183318,
            0.40297733,
            0.41050298,
            0.332879,
            0.48799542,
            0.69231828,
            0.77015091,
            0.60410597,
            0.79788484,
            0.88232104,
            0.92279623,
            1.0,
            0.45685017,
            0.34261229,
        ],
        [
            0.32485771,
            0.39371443,
            0.47266611,
            0.24171562,
            0.62997328,
            0.48311542,
            0.49535971,
            0.32477932,
            0.51486622,
            0.79353556,
            0.69768738,
            0.55796926,
            1.0,
            0.92373745,
            0.8856886,
            0.79788484,
            0.47883134,
            0.35360175,
        ],
        [
            0.37248222,
            0.4319752,
            0.50925436,
            0.25433078,
            0.65638548,
            0.4982002,
            0.53134951,
            0.38057074,
            0.52403969,
            0.72035243,
            0.78711147,
            0.58842844,
            0.92373745,
            1.0,
            0.81673412,
            0.88232104,
            0.47109935,
            0.36156453,
        ],
        [
            0.36865639,
            0.46514124,
            0.45202552,
            0.37704433,
            0.66599594,
            0.44010641,
            0.48484363,
            0.39636574,
            0.50175258,
            1.0,
            0.91320249,
            0.5397475,
            0.79353556,
            0.72035243,
            0.77400821,
            0.69231828,
            0.59087008,
            0.38142307,
        ],
        [
            0.41500332,
            0.50083501,
            0.48421,
            0.38810141,
            0.68567555,
            0.44921156,
            0.51615925,
            0.45156472,
            0.50438158,
            0.91320249,
            1.0,
            0.56429928,
            0.69768738,
            0.78711147,
            0.68813814,
            0.77015091,
            0.57698754,
            0.38525582,
        ],
    ]
)


def normalize_adj(adj):
    """Symmetrically normalize adjacency matrix."""
    adj = sp.coo_matrix(adj)
    rowsum = np.array(adj.sum(1))
    d_inv_sqrt = np.power(rowsum, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()


def preprocess_adj(adj):
    """Preprocessing of adjacency matrix for simple GCN model and conversion to tuple representation."""
    adj = nx.adjacency_matrix(
        nx.from_dict_of_lists(adj)
    )  # return a adjacency matrix of adj ( type is numpy)
    adj_normalized = normalize_adj(adj + sp.eye(adj.shape[0]))  #
    # return sparse_to_tuple(adj_normalized)
    return adj_normalized.todense()


def row_norm(inputs):
    outputs = []
    for x in inputs:
        xsum = x.sum()
        x = x / xsum
        outputs.append(x)
    return outputs


def normalize_adj_torch(adj):
    # print(adj.size())
    if len(adj.size()) == 4:
        new_r = torch.zeros(adj.size()).type_as(adj)
        for i in range(adj.size(1)):
            adj_item = adj[0, i]
            rowsum = adj_item.sum(1)
            d_inv_sqrt = rowsum.pow_(-0.5)
            d_inv_sqrt[torch.isnan(d_inv_sqrt)] = 0
            d_mat_inv_sqrt = torch.diag(d_inv_sqrt)
            r = torch.matmul(torch.matmul(d_mat_inv_sqrt, adj_item), d_mat_inv_sqrt)
            new_r[0, i, ...] = r
        return new_r
    rowsum = adj.sum(1)
    d_inv_sqrt = rowsum.pow_(-0.5)
    d_inv_sqrt[torch.isnan(d_inv_sqrt)] = 0
    d_mat_inv_sqrt = torch.diag(d_inv_sqrt)
    r = torch.matmul(torch.matmul(d_mat_inv_sqrt, adj), d_mat_inv_sqrt)
    return r


# def row_norm(adj):


if __name__ == "__main__":
    a = row_norm(cihp2pascal_adj)
    print(a)
    print(cihp2pascal_adj)
    # print(a.shape)
